vllm fakequant reload with modelopt state for HF #805
kinjalpatel27 wants to merge 30 commits into main from
Conversation
📝 Walkthrough

Adds end-to-end support for reloading vLLM fakequant state from ModelOpt HF exports: new HF PTQ exporter, vLLM conversion/sharding utilities, worker-side loading paths.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Worker as vLLM Worker
    participant FS as File System
    participant Converter as vLLM Reload Utils
    participant Model as Model
    Worker->>FS: read MODELOPT_STATE_PATH / QUANT_FILE_PATH
    FS-->>Worker: modelopt_state / quant_state
    Worker->>Converter: convert_modelopt_state_to_vllm(modelopt_state)
    Converter-->>Worker: vLLM-formatted state
    Worker->>Converter: filter_modelopt_state_quantizer_state_for_model(vLLM_state, Model)
    Converter-->>Worker: filtered state aligned to Model keys
    Worker->>Converter: process_state_dict_for_tp(filtered_state, Model.state_dict())
    Converter-->>Worker: TP-sharded state
    Worker->>Model: restore_from_modelopt_state_vllm(Model, TP-sharded state)
    Model-->>Worker: model restored with quantizer state
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging; addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)
✅ Passed checks (2 passed)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@           Coverage Diff            @@
##             main     #805    +/-  ##
========================================
+ Coverage   70.09%   70.11%   +0.02%
========================================
  Files         221      221
  Lines       25459    25459
========================================
+ Hits        17845    17851       +6
+ Misses       7614     7608       -6
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/nn/modules/quant_module.py`:
- Around line 62-72: The fallback block currently unconditionally overwrites
non_tq_param_or_buffer; change the guard to only run when non_tq_param_or_buffer
is None (e.g., if non_tq_param_or_buffer is None and model is not None) so the
first-found parameter isn't clobbered, and when computing the parent module use
model.get_submodule(parent_prefix) if parent_prefix else model to avoid calling
get_submodule with an empty string; also either implement the intended filtering
to skip TensorQuantizer-owned params (add a check inside the loop to continue if
the param belongs to a TensorQuantizer) or update the comment to reflect that
the code simply takes the first parameter, referencing non_tq_param_or_buffer,
prefix, and model.get_submodule to locate where to change.
🧹 Nitpick comments (5)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)
33-43: Docstring is outdated and references amax instead of quantizer state.

The docstring still describes extracting "amax values", but the implementation now saves the complete quantizer state dict and modelopt state. Update it to reflect the actual behavior.
📝 Suggested docstring update
```diff
-    """Exports the torch model weights and amax values separately.
+    """Exports the torch model weights and quantizer state separately for vLLM fakequant.

     This function:
-    1. Extracts amax values for calibration
+    1. Extracts quantizer state dict and modelopt state
     2. Deletes all quantizer parameters from state dict to store only weights in original dtype
     3. Saves the model weights

     Args:
         model: The quantized model to export
-        export_dir: Directory to save the amax values
+        export_dir: Directory to save the model and quantizer state
     """
```

modelopt/torch/export/plugins/vllm_fakequant_megatron.py (1)
46-64: Comments reference "amax" but code now handles full quantizer state.

Several comments still reference "amax" (lines 46, 51, 63), but the code now handles the complete quantizer state dictionary. Consider updating for clarity.
📝 Suggested comment updates
```diff
-    # Gather all amax dicts to rank 0
+    # Gather all quantizer state dicts to rank 0
     world_size = torch.distributed.get_world_size()
     rank = torch.distributed.get_rank()
     if rank == 0:
-        # Rank 0 will collect all amax values
+        # Rank 0 will collect all quantizer state values
         all_quantizer_state_dicts = [None] * world_size
         torch.distributed.gather_object(quantizer_state_dict, all_quantizer_state_dicts, dst=0)
         ...
     else:
-        # Other ranks just send their amax values
+        # Other ranks send their quantizer state values
         torch.distributed.gather_object(quantizer_state_dict, None, dst=0)
```

examples/vllm_serve/vllm_reload_utils.py (1)
175-185: Docstring references non-existent parameter `fuse_experts`.

The docstring mentions a `fuse_experts` parameter, but the function only has `state_dict` and `merge_mode` parameters.

📝 Suggested docstring fix
```diff
     """
     Common implementation for converting quantizer state from HF to vLLM format.

     Args:
         state_dict: Input state dict
-        fuse_experts: Whether to fuse expert projections
         merge_mode: Mode to merge grouped values, "max_or_concat" or "require_identical"
+
+    Returns:
+        Converted state dict in vLLM format.
     """
```

examples/vllm_serve/fakequant_worker.py (2)
115-124: Consider adding a comment about `weights_only=False` security implications.

Using `weights_only=False` in `torch.load` is necessary for loading complex modelopt state, but it allows arbitrary code execution from untrusted files. The current code is fine for trusted checkpoints, but a brief comment noting this would be helpful for future maintainers.

💡 Optional: Add security note
```diff
 # Load on CPU to avoid failures when the checkpoint was saved from a different
 # GPU mapping
+# Note: weights_only=False is required for modelopt state but should only be used
+# with trusted checkpoint files.
 modelopt_state = torch.load(
     quant_config["modelopt_state_path"], weights_only=False, map_location="cpu"
 )
```
235-242: Asymmetric key validation is intentional but could use a comment.

The code raises an error when model keys are missing from the checkpoint but only warns when the checkpoint has extra keys. This asymmetry makes sense (the model requires all its quantizers to be loaded), but a brief comment explaining the rationale would improve clarity.
💡 Optional: Add clarifying comment
```diff
+# Checkpoint may have extra keys (e.g., from PP sharding), but model must have
+# all its quantizer keys present in the checkpoint for correct loading
 for key in checkpoint_quant_keys:
     if key not in model_quant_keys:
         print(f"Key {key} not found in model state dict, but exists in checkpoint")
 for key in model_quant_keys:
     if key not in checkpoint_quant_keys:
         raise ValueError(
             f"Key {key} not found in checkpoint state dict, but exists in model"
         )
```
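Standalone, the warn-vs-raise asymmetry above can be sketched as a small helper (function name hypothetical, not part of the PR):

```python
def check_quantizer_keys(model_keys, checkpoint_keys):
    """Warn on extra checkpoint keys; fail if the model is missing any.

    The checkpoint may legitimately carry extra entries (e.g. from PP
    sharding), but every quantizer key the model owns must be present.
    Returns the list of extra checkpoint keys for logging.
    """
    extra = sorted(set(checkpoint_keys) - set(model_keys))
    for key in extra:
        print(f"Key {key} exists in checkpoint but not in model state dict")
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    if missing:
        raise ValueError(
            f"Keys {missing} not found in checkpoint state dict, but exist in model"
        )
    return extra
```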
@kinjalpatel27 A quick question here: the HF exported checkpoint is a real quantized checkpoint, so why do we want to load it in vLLM fakequant? Is it for evaluation purposes (in case the vllm deployment is not supported)?
@Edwardf0t1 The HF exported checkpoint is for vLLM fakequant. It exports BF16 weights and quantizer states, which can be reloaded in the vLLM fakequant serve script (https://github.com/NVIDIA/Model-Optimizer/blob/4f3d698e332d1d746c3b08a0e9cbc98dc2c67b84/examples/vllm_serve/fakequant_worker.py#L111:L124)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/vllm_serve/fakequant_worker.py (1)
103-106: ⚠️ Potential issue | 🔴 Critical

CRITICAL: `trust_remote_code=True` is hardcoded.

Per coding guidelines, `trust_remote_code` must be exposed as a caller-configurable parameter defaulting to `False`, not hardcoded to `True`. This is a security risk, as it allows execution of arbitrary code from model repositories.

🔧 Proposed fix - make trust_remote_code configurable via environment variable
```diff
 tokenizer = AutoTokenizer.from_pretrained(
     self.model_runner.model_config.tokenizer,
-    trust_remote_code=True,
+    trust_remote_code=self.model_runner.model_config.trust_remote_code,
 )
```

Alternatively, if the model_config doesn't expose this, add an environment variable:
```diff
+trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
 tokenizer = AutoTokenizer.from_pretrained(
     self.model_runner.model_config.tokenizer,
-    trust_remote_code=True,
+    trust_remote_code=trust_remote_code,
 )
```

As per coding guidelines: "trust_remote_code must be exposed as a caller-configurable parameter defaulting to False for transformers model or tokenizer loading, not hardcoded to True"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` around lines 103 - 106, The call to AutoTokenizer.from_pretrained hardcodes trust_remote_code=True which is a security risk; change the code so trust_remote_code is caller-configurable and defaults to False by reading a new parameter (e.g., add a boolean flag on the worker or model runner config, or fallback to an env var) and pass that flag to AutoTokenizer.from_pretrained instead of True; update references around tokenizer and self.model_runner.model_config.tokenizer to use the new flag (e.g., trust_remote_code = self.model_runner.model_config.trust_remote_code or os.getenv(..., "false") ) so callers can opt-in while preserving a safe default.
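For the env-var route, the opt-in parsing can be sketched as a tiny standalone helper (helper and variable names are hypothetical; the default stays `False` so callers must opt in explicitly):

```python
import os


def trust_remote_code_from_env(env_var: str = "TRUST_REMOTE_CODE") -> bool:
    """Parse an opt-in boolean flag from the environment, defaulting to False.

    Only "1" or "true" (case-insensitive) enable the flag; anything else,
    including an unset variable, keeps the safe default.
    """
    return os.environ.get(env_var, "false").lower() in ("1", "true")
```

The result would then be passed as `trust_remote_code=trust_remote_code_from_env()` to `AutoTokenizer.from_pretrained`, so the unsafe behavior is never the default.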
🧹 Nitpick comments (3)
examples/vllm_serve/hf_ptq_export.py (1)
41-41: Consider moving `enable_huggingface_checkpointing()` inside `main()`.

Calling `mto.enable_huggingface_checkpointing()` at module level means it executes whenever this module is imported, which could cause unexpected side effects if the module is imported for other purposes (e.g., testing individual functions).

♻️ Proposed fix

```diff
-mto.enable_huggingface_checkpointing()
-
-
 def load_model(
     ...
+
+
+def main(args: argparse.Namespace):
+    if not torch.cuda.is_available():
+        raise OSError("GPU is required for inference.")
+
+    # Enable ModelOpt save/restore for HuggingFace models
+    mto.enable_huggingface_checkpointing()
+
+    random.seed(RAND_SEED)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/hf_ptq_export.py` at line 41, The call to mto.enable_huggingface_checkpointing() is currently at module import time and should be moved inside main() to avoid side effects on import; locate the module-level invocation of enable_huggingface_checkpointing(), remove it from top-level, and add a call at the start of the main() function (before any model loading or IO) so checkpointing is enabled only when main() runs.

modelopt/torch/quantization/plugins/custom.py (1)
117-117: Consider forwarding `**kwargs` to `super().modelopt_post_restore()`.

The signature now accepts `*args, **kwargs` for compatibility, but the `super()` call at line 174 only passes `prefix=prefix`. If a caller passes `model=...`, it won't be forwarded to the base class `QuantModule.modelopt_post_restore()`, which now supports the `model` parameter for device discovery.

♻️ Proposed fix

```diff
     # If there are any other states, lets move them to the correct device
-    super().modelopt_post_restore(prefix=prefix)
+    super().modelopt_post_restore(prefix=prefix, **kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/plugins/custom.py` at line 117, The override modelopt_post_restore(self, prefix: str = "", *args, **kwargs) must forward any additional keyword args to the base implementation; update the call to super().modelopt_post_restore(...) (the QuantModule.modelopt_post_restore path) to pass prefix=prefix and also **kwargs (and *args if present) so parameters such as model=... used for device discovery are propagated to the parent implementation.

modelopt/torch/export/plugins/vllm_fakequant_megatron.py (1)
55-60: Variable shadowing in loop.

The loop variable `quantizer_state_dict` shadows the outer variable of the same name defined at line 42. This makes the code harder to read and could lead to confusion.

♻️ Proposed fix
```diff
 # Merge all quantizer state dicts into one
 merged_quantizer_state_dict = {}
-for quantizer_state_dict in all_quantizer_state_dicts:
-    if quantizer_state_dict is not None:
-        merged_quantizer_state_dict.update(quantizer_state_dict)
+for qstate_dict in all_quantizer_state_dicts:
+    if qstate_dict is not None:
+        merged_quantizer_state_dict.update(qstate_dict)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/vllm_fakequant_megatron.py` around lines 55 - 60, The loop is shadowing the outer variable quantizer_state_dict; rename the loop variable (for example to iter_quantizer_state or q_state_dict) in the merge block that iterates over all_quantizer_state_dicts and update any references in that loop to use the new name so the outer quantizer_state_dict remains intact; ensure merged_quantizer_state_dict.update(...) uses the renamed loop variable and run tests to validate no other references rely on the old name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 115-117: The torch.load call that sets modelopt_state
(torch.load(quant_config["modelopt_state_path"], weights_only=False,
map_location="cpu")) must include an inline comment justifying why using
weights_only=False is safe; update the line near modelopt_state assignment to
add a brief inline comment stating that the file at
quant_config["modelopt_state_path"] is generated internally (not user-supplied),
comes from a trusted build/release, or is validated beforehand so arbitrary code
execution is not a concern.
- Line 219: The torch.load call that assigns saved_quant_dict from
quantizer_file_path should include the weights_only parameter to avoid unsafe
deserialization; update the torch.load(...) call (the one creating
saved_quant_dict) to use weights_only=True if the checkpoint contains only
tensors, or set weights_only=False and add an inline comment justifying why
non-tensor objects are safe to load per guidelines.
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 180-184: The map_fun parameter on convert_dict_to_vllm is
mis-typed and mis-defaulted: it is actually invoked with a single dict and
should accept/return a dict (see call at line where map_fun(state_dict) and
usage in fakequant_worker.py referencing hf_to_vllm_mapper.apply_dict). Update
the type hint from Callable[[str, Any], tuple[str, Any]] to Callable[[dict[str,
Any]], dict[str, Any]] and change the default from lambda x: x to a dict
identity like lambda d: d so convert_dict_to_vllm(state_dict, ..., map_fun) and
any callers expecting apply_dict match correctly; ensure any internal code
treats map_fun output as a dict.
In `@modelopt/torch/quantization/nn/modules/quant_module.py`:
- Around line 120-126: The loop that checks parent_module.named_parameters uses
param_parent_name and calls parent_module.get_submodule(param_parent_name),
which can misbehave when param_parent_name is empty; update the logic in the
loop (around param_parent_name and param_parent) to fall back to using
parent_module itself when param_parent_name == "" (i.e., set param_parent =
parent_module) before testing isinstance(param_parent, TensorQuantizer) so the
check for TensorQuantizer and assignment to non_tq_param_or_buffer behaves
correctly across PyTorch versions.
---
Outside diff comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 103-106: The call to AutoTokenizer.from_pretrained hardcodes
trust_remote_code=True which is a security risk; change the code so
trust_remote_code is caller-configurable and defaults to False by reading a new
parameter (e.g., add a boolean flag on the worker or model runner config, or
fallback to an env var) and pass that flag to AutoTokenizer.from_pretrained
instead of True; update references around tokenizer and
self.model_runner.model_config.tokenizer to use the new flag (e.g.,
trust_remote_code = self.model_runner.model_config.trust_remote_code or
os.getenv(..., "false") ) so callers can opt-in while preserving a safe default.
---
Nitpick comments:
In `@examples/vllm_serve/hf_ptq_export.py`:
- Line 41: The call to mto.enable_huggingface_checkpointing() is currently at
module import time and should be moved inside main() to avoid side effects on
import; locate the module-level invocation of
enable_huggingface_checkpointing(), remove it from top-level, and add a call at
the start of the main() function (before any model loading or IO) so
checkpointing is enabled only when main() runs.
In `@modelopt/torch/export/plugins/vllm_fakequant_megatron.py`:
- Around line 55-60: The loop is shadowing the outer variable
quantizer_state_dict; rename the loop variable (for example to
iter_quantizer_state or q_state_dict) in the merge block that iterates over
all_quantizer_state_dicts and update any references in that loop to use the new
name so the outer quantizer_state_dict remains intact; ensure
merged_quantizer_state_dict.update(...) uses the renamed loop variable and run
tests to validate no other references rely on the old name.
In `@modelopt/torch/quantization/plugins/custom.py`:
- Line 117: The override modelopt_post_restore(self, prefix: str = "", *args,
**kwargs) must forward any additional keyword args to the base implementation;
update the call to super().modelopt_post_restore(...) (the
QuantModule.modelopt_post_restore path) to pass prefix=prefix and also **kwargs
(and *args if present) so parameters such as model=... used for device discovery
are propagated to the parent implementation.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- CHANGELOG.rst
- examples/vllm_serve/README.md
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/hf_ptq_export.py
- examples/vllm_serve/vllm_reload_utils.py
- examples/vllm_serve/vllm_serve_fakequant.py
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
- modelopt/torch/export/plugins/vllm_fakequant_megatron.py
- modelopt/torch/quantization/conversion.py
- modelopt/torch/quantization/nn/modules/quant_module.py
- modelopt/torch/quantization/plugins/custom.py
- modelopt/torch/quantization/plugins/megatron.py
- modelopt/torch/quantization/plugins/transformer_engine.py
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
- CHANGELOG.rst
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/vllm_serve/fakequant_worker.py (1)
103-106: ⚠️ Potential issue | 🔴 Critical

Do not hardcode `trust_remote_code=True` for tokenizer loading.

Per security guidelines, `trust_remote_code=True` must be caller-configurable and default to `False`. Hardcoding it opens an unnecessary remote code execution surface.

Consider making this configurable via an environment variable or function parameter:
🔧 Proposed fix
```diff
+trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", "0") == "1"
 tokenizer = AutoTokenizer.from_pretrained(
     self.model_runner.model_config.tokenizer,
-    trust_remote_code=True,
+    trust_remote_code=trust_remote_code,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` around lines 103 - 106, The code hardcodes trust_remote_code=True when calling AutoTokenizer.from_pretrained; change this to be caller-configurable and default to False by reading a configuration flag (e.g., an environment variable or a constructor/runner parameter) instead of hardcoding. Update the call in fakequant_worker.py where AutoTokenizer.from_pretrained is invoked (refer to tokenizer and self.model_runner.model_config.tokenizer) to pass trust_remote_code=(config_flag) where config_flag defaults to False and can be set by the caller or via an env var like TRUST_REMOTE_CODE; ensure the new flag is documented where the worker/ModelRunner is constructed.
♻️ Duplicate comments (2)
examples/vllm_serve/fakequant_worker.py (2)
115-117: ⚠️ Potential issue | 🔴 Critical

`weights_only=False` load still lacks the required safety justification.

This deserializes arbitrary pickled objects; add an inline safety justification (trusted/internal artifact) at the load site.
🔧 Proposed fix
```diff
+# Safe: this file is an internally generated ModelOpt artifact from a trusted pipeline.
 modelopt_state = torch.load(
     quant_config["modelopt_state_path"], weights_only=False, map_location="cpu"
 )
```

As per coding guidelines, `torch.load(..., weights_only=False)` requires an inline comment justifying why it is safe.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` around lines 115 - 117, The torch.load call that sets modelopt_state using quant_config["modelopt_state_path"] currently passes weights_only=False without an inline safety justification; update the call site (the torch.load(...) that assigns modelopt_state) to include a brief inline comment explaining why using weights_only=False is safe (e.g., the file is a trusted/internal artifact produced by our pipeline and contains no untrusted pickles), referencing the trusted provenance and any validation performed.
223-223: ⚠️ Potential issue | 🟠 Major

Set `weights_only` explicitly when loading quantizer files.

Leaving this implicit makes behavior runtime-version dependent and can allow unsafe object deserialization on older PyTorch versions.
🔧 Proposed fix
```diff
-saved_quant_dict = torch.load(quantizer_file_path, map_location="cpu")
+saved_quant_dict = torch.load(
+    quantizer_file_path, map_location="cpu", weights_only=True
+)
```

As per coding guidelines, unsafe `torch.load` usage must be explicitly constrained/justified rather than relying on implicit defaults.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` at line 223, The torch.load call that reads the quantizer file (saved_quant_dict = torch.load(quantizer_file_path, map_location="cpu")) must explicitly pass weights_only=True to avoid unsafe object deserialization and version-dependent behavior; update the load invocation to include weights_only=True and, if supporting older PyTorch versions, wrap the call in a try/except that falls back to a safe alternative or raises a clear error indicating the required minimum torch version—ensure references to quantizer_file_path and saved_quant_dict are retained so the change is applied to the correct load site.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 229-233: The comprehension filtering saved_quant_dict uses
endswith("quantizer_") which will drop most quantizer entries; update the filter
to check for the substring "quantizer_" (e.g., if "quantizer_" in key) and
perform a targeted replacement (use key.replace("quantizer_", "quantizer._", 1))
so all keys containing quantizer_ are remapped into saved_quant_dict; modify the
dict comprehension that builds saved_quant_dict accordingly.
- Around line 143-151: The calibration loop in calibrate_loop currently only
uses batch["input_ids"][0], ignoring calib_batch_size and dropping samples;
update the loop to iterate over all samples in each batch (e.g., for each sample
in batch["input_ids"]) or slice the tensor to calib_batch_size, convert each
sample tensor to a list for vLLM compatibility, and feed all samples to the
calibration logic so every item in calib_dataloader batch is used; locate
calibrate_loop, calib_dataloader, and places referencing
input_ids/input_ids_list to implement this change.
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 199-203: The code silently defaults to
_merge_values_by_max_or_concat when merge_mode is mistyped; add explicit
validation for merge_mode before selecting merge_fn: check that merge_mode is
one of the allowed values ("require_identical" or "max_or_concat"), raise a
ValueError with a clear message if not, and then set merge_fn to
_merge_values_require_identical when merge_mode == "require_identical" else
_merge_values_by_max_or_concat; reference the symbols merge_mode, merge_fn,
_merge_values_require_identical, and _merge_values_by_max_or_concat when
modifying the selection logic.
---
Outside diff comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 103-106: The code hardcodes trust_remote_code=True when calling
AutoTokenizer.from_pretrained; change this to be caller-configurable and default
to False by reading a configuration flag (e.g., an environment variable or a
constructor/runner parameter) instead of hardcoding. Update the call in
fakequant_worker.py where AutoTokenizer.from_pretrained is invoked (refer to
tokenizer and self.model_runner.model_config.tokenizer) to pass
trust_remote_code=(config_flag) where config_flag defaults to False and can be
set by the caller or via an env var like TRUST_REMOTE_CODE; ensure the new flag
is documented where the worker/ModelRunner is constructed.
---
Duplicate comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 115-117: The torch.load call that sets modelopt_state using
quant_config["modelopt_state_path"] currently passes weights_only=False without
an inline safety justification; update the call site (the torch.load(...) that
assigns modelopt_state) to include a brief inline comment explaining why using
weights_only=False is safe (e.g., the file is a trusted/internal artifact
produced by our pipeline and contains no untrusted pickles), referencing the
trusted provenance and any validation performed.
- Line 223: The torch.load call that reads the quantizer file (saved_quant_dict
= torch.load(quantizer_file_path, map_location="cpu")) must explicitly pass
weights_only=True to avoid unsafe object deserialization and version-dependent
behavior; update the load invocation to include weights_only=True and, if
supporting older PyTorch versions, wrap the call in a try/except that falls back
to a safe alternative or raises a clear error indicating the required minimum
torch version—ensure references to quantizer_file_path and saved_quant_dict are
retained so the change is applied to the correct load site.
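The substring-based quantizer key remap described in the prompts above (matching `"quantizer_"` anywhere in the key and replacing only its first occurrence with `"quantizer._"`) can be sketched as a standalone helper (name hypothetical):

```python
def remap_quantizer_keys(saved_quant_dict):
    """Remap flattened quantizer keys using a substring match, not endswith.

    Keys like "layers.0.input_quantizer_amax" become
    "layers.0.input_quantizer._amax"; only the first "quantizer_" occurrence
    is replaced, and keys without the substring pass through unchanged.
    """
    return {
        (key.replace("quantizer_", "quantizer._", 1) if "quantizer_" in key else key): value
        for key, value in saved_quant_dict.items()
    }
```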
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 32c569c1-d958-470a-a629-b357a57d4fa6
📒 Files selected for processing (2)
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/vllm_reload_utils.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
examples/vllm_serve/vllm_reload_utils.py (1)
200-204: ⚠️ Potential issue | 🟡 Minor

Validate `merge_mode` explicitly; don’t silently fallback.

A typo currently falls through to `"max_or_concat"` and can apply the wrong merge strategy without surfacing an error.

Proposed fix
```diff
-merge_fn = (
-    _merge_values_require_identical
-    if merge_mode == "require_identical"
-    else _merge_values_by_max_or_concat
-)
+if merge_mode == "require_identical":
+    merge_fn = _merge_values_require_identical
+elif merge_mode == "max_or_concat":
+    merge_fn = _merge_values_by_max_or_concat
+else:
+    raise ValueError(f"Unsupported merge_mode: {merge_mode}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` around lines 200 - 204, The current selection of merge_fn silently falls back to _merge_values_by_max_or_concat when merge_mode is misspelled; change the logic that assigns merge_fn (the conditional using merge_mode) to explicitly validate merge_mode and only accept known values ("require_identical" and "max_or_concat"), selecting _merge_values_require_identical or _merge_values_by_max_or_concat accordingly and raising a clear ValueError (or similar) if merge_mode is invalid so typos don’t silently change behavior.
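A table-driven variant of the same validation can be sketched independently of the PR's `_merge_values_*` helpers (the dispatch-table shape and function name here are hypothetical):

```python
def select_merge_fn(merge_mode, merge_fns):
    """Dispatch on merge_mode, rejecting unknown values.

    Unlike an if/else with a bare fallback, a lookup plus explicit check
    makes a typo fail loudly instead of silently picking "max_or_concat".
    """
    if merge_mode not in merge_fns:
        raise ValueError(
            f"Unsupported merge_mode: {merge_mode!r}; expected one of {sorted(merge_fns)}"
        )
    return merge_fns[merge_mode]
```

Callers would register the real merge functions once, e.g. `{"require_identical": ..., "max_or_concat": ...}`, so adding a mode is a one-line change.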
🧹 Nitpick comments (1)
examples/vllm_serve/vllm_reload_utils.py (1)
195-196: Remove stale `fuse_experts` docstring entry.

`convert_dict_to_vllm` has no `fuse_experts` parameter, so this is now misleading API documentation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` around lines 195 - 196, The docstring for convert_dict_to_vllm incorrectly documents a non-existent parameter "fuse_experts"; remove the stale "fuse_experts: Whether to fuse expert projections" entry from the function/class docstring (search for convert_dict_to_vllm in vllm_reload_utils.py) and update the remaining parameter list so it only documents actual parameters (e.g., keep/verify merge_mode text matches the function signature).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 342-350: Before computing shard_dims, guard that the checkpoint
entry `value` actually has a shape with the expected rank/type: check that
hasattr(value, "shape") and that `value.shape` is indexable (e.g., tuple/list)
and len(value.shape) >= len(expected_shape); if not, raise a clear ValueError
mentioning the tensor key and shapes. Then compute `shard_dims` by iterating
only over dimensions that exist in `value.shape` (avoid directly indexing inside
the list comprehension) and compare `value.shape[d] == expected_shape[d] *
tp_world_size`; this prevents AttributeError/IndexError when `value` is not a
tensor of the expected rank before the intended mismatch handling for
`value.shape != expected_shape`.
---
Duplicate comments:
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 200-204: The current selection of merge_fn silently falls back to
_merge_values_by_max_or_concat when merge_mode is misspelled; change the logic
that assigns merge_fn (the conditional using merge_mode) to explicitly validate
merge_mode and only accept known values ("require_identical" and
"max_or_concat"), selecting _merge_values_require_identical or
_merge_values_by_max_or_concat accordingly and raising a clear ValueError (or
similar) if merge_mode is invalid so typos don’t silently change behavior.
---
Nitpick comments:
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 195-196: The docstring for convert_dict_to_vllm incorrectly
documents a non-existent parameter "fuse_experts"; remove the stale
"fuse_experts: Whether to fuse expert projections" entry from the function/class
docstring (search for convert_dict_to_vllm in vllm_reload_utils.py) and update
the remaining parameter list so it only documents actual parameters (e.g.,
keep/verify merge_mode text matches the function signature).
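The guarded shard-dimension inference suggested in the prompts above can be sketched with plain shape tuples (helper name hypothetical; the real code would operate on a tensor's `.shape` and `tp_world_size` from the vLLM config):

```python
def infer_shard_dims(value_shape, expected_shape, tp_world_size):
    """Find dims where the checkpoint tensor is tp_world_size times the local shard.

    Guards the rank mismatch first so the comparison never indexes past the
    end of either shape, avoiding IndexError before the intended handling.
    """
    if len(value_shape) != len(expected_shape):
        raise ValueError(
            f"Rank mismatch: checkpoint shape {value_shape} vs model shape {expected_shape}"
        )
    return [
        d
        for d in range(len(expected_shape))
        if value_shape[d] == expected_shape[d] * tp_world_size
    ]
```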
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 79b2b0fd-34ec-4b46-ba69-75d992d95b41
📒 Files selected for processing (1)
examples/vllm_serve/vllm_reload_utils.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/vllm_serve/fakequant_worker.py (1)
103-106: ⚠️ Potential issue | 🔴 Critical

CRITICAL: `trust_remote_code=True` is hardcoded.

Per coding guidelines, `trust_remote_code` must be exposed as a caller-configurable parameter defaulting to `False`, not hardcoded to `True`. This flag allows execution of arbitrary Python shipped with a checkpoint, which is an RCE vector if the model source is untrusted.

🔧 Proposed fix - expose as configurable parameter
 quant_config: dict[str, Any] = {
     "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
     "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
     "quant_cfg": os.environ.get("QUANT_CFG", None),
     "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None),
     "quant_file_path": os.environ.get("QUANT_FILE_PATH", None),
     "modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None),
     "calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 1)),
+    "trust_remote_code": os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true",
 }

Then update the tokenizer loading:
 tokenizer = AutoTokenizer.from_pretrained(
     self.model_runner.model_config.tokenizer,
-    trust_remote_code=True,
+    trust_remote_code=quant_config["trust_remote_code"],
 )

As per coding guidelines: "trust_remote_code parameter for transformers model or tokenizer loading must be exposed as a caller-configurable parameter defaulting to False, not hardcoded to True"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` around lines 103 - 106, The tokenizer call currently hardcodes trust_remote_code=True; change this to use a caller-configurable boolean (default False) exposed on the worker/config that constructs the model runner (e.g., add a parameter like trust_remote_code=False to the FakeQuantWorker constructor or its config), then pass that value into AutoTokenizer.from_pretrained(...) instead of True, keeping the same tokenizer key self.model_runner.model_config.tokenizer; update any instantiation sites to accept the new parameter or use the default.
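A sketch of the env-driven default the fix proposes. The `TRUST_REMOTE_CODE` variable name comes from the diff above; the accepted truthy spellings are an assumption:

```python
import os


def resolve_trust_remote_code(env=None) -> bool:
    """Parse TRUST_REMOTE_CODE from the environment, defaulting to False so
    remote code execution requires an explicit opt-in."""
    env = os.environ if env is None else env
    return env.get("TRUST_REMOTE_CODE", "false").strip().lower() in {"1", "true", "yes"}
```

Calling `AutoTokenizer.from_pretrained(..., trust_remote_code=resolve_trust_remote_code())` then keeps the safe default unless the operator sets the variable explicitly.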
🧹 Nitpick comments (2)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
119-130: Consider a buffer fallback when parent has no parameters.

Line 120 only inspects `named_parameters()`. For buffer-only parent modules, device detection still fails even when a valid non-`TensorQuantizer` buffer exists. Adding a `named_buffers()` fallback would make restore more robust.

♻️ Proposed refactor
 # Look for any parameter in parent module (not just state_dict)
 for name, param in parent_module.named_parameters():
     # Skip params that belong to TensorQuantizer submodules
     param_parent_name = name.rsplit(".", 1)[0] if "." in name else ""
     param_parent = (
         parent_module.get_submodule(param_parent_name)
         if param_parent_name
         else parent_module
     )
     if not isinstance(param_parent, TensorQuantizer):
         non_tq_param_or_buffer = param
         break
+
+# Fallback: some modules may be buffer-only
+if non_tq_param_or_buffer is None:
+    for name, buf in parent_module.named_buffers():
+        buf_parent_name = name.rsplit(".", 1)[0] if "." in name else ""
+        buf_parent = (
+            parent_module.get_submodule(buf_parent_name)
+            if buf_parent_name
+            else parent_module
+        )
+        if not isinstance(buf_parent, TensorQuantizer):
+            non_tq_param_or_buffer = buf
+            break
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/nn/modules/quant_module.py` around lines 119 - 130, The parent-module device detection loop in quant_module.py currently only checks parent_module.named_parameters() and thus misses modules that expose only buffers; update the logic in the restore/device-detection area (the loop that iterates parent_module.named_parameters() and sets non_tq_param_or_buffer) to also fall back to parent_module.named_buffers() when no suitable parameter is found, applying the same TensorQuantizer parent check (using param_parent_name and parent_module.get_submodule(...)) so a non-TensorQuantizer buffer can be selected; ensure the code still prefers parameters over buffers but uses a buffer if no parameter exists.

examples/vllm_serve/vllm_reload_utils.py (1)
140-163: Consider adding a defensive check for empty `key_value_pairs`.

If `key_value_pairs` is empty, `values[0]` will raise an `IndexError`. While the caller (`convert_dict_to_vllm`) currently only calls this when `len(key_value_pairs) > 1`, adding a guard would make the function more robust against future misuse.

🛡️ Proposed defensive check
 def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any:
     """
     Merge values by taking max for amax, concatenating for others.

     Used for quantizer state weights (tensor values).
     """
+    if not key_value_pairs:
+        raise ValueError(f"Cannot merge empty key_value_pairs for '{merged_key}'")
     values = [value for _, value in key_value_pairs]
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` around lines 140 - 163, The function _merge_values_by_max_or_concat should defensively handle an empty key_value_pairs input to avoid IndexError on values[0]; add an early guard at the top of the function that checks if not key_value_pairs and raise a clear ValueError (or return an appropriate sentinel) with a message referencing merged_key; keep the rest of logic unchanged—this makes the helper robust even if convert_dict_to_vllm or other callers misuse it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 103-106: The tokenizer call currently hardcodes
trust_remote_code=True; change this to use a caller-configurable boolean
(default False) exposed on the worker/config that constructs the model runner
(e.g., add a parameter like trust_remote_code=False to the FakeQuantWorker
constructor or its config), then pass that value into
AutoTokenizer.from_pretrained(...) instead of True, keeping the same tokenizer
key self.model_runner.model_config.tokenizer; update any instantiation sites to
accept the new parameter or use the default.
---
Nitpick comments:
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 140-163: The function _merge_values_by_max_or_concat should
defensively handle an empty key_value_pairs input to avoid IndexError on
values[0]; add an early guard at the top of the function that checks if not
key_value_pairs and raise a clear ValueError (or return an appropriate sentinel)
with a message referencing merged_key; keep the rest of logic unchanged—this
makes the helper robust even if convert_dict_to_vllm or other callers misuse it.
In `@modelopt/torch/quantization/nn/modules/quant_module.py`:
- Around line 119-130: The parent-module device detection loop in
quant_module.py currently only checks parent_module.named_parameters() and thus
misses modules that expose only buffers; update the logic in the
restore/device-detection area (the loop that iterates
parent_module.named_parameters() and sets non_tq_param_or_buffer) to also fall
back to parent_module.named_buffers() when no suitable parameter is found,
applying the same TensorQuantizer parent check (using param_parent_name and
parent_module.get_submodule(...)) so a non-TensorQuantizer buffer can be
selected; ensure the code still prefers parameters over buffers but uses a
buffer if no parameter exists.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4d49b193-8997-4fe4-8b02-069e8fcd8419
📒 Files selected for processing (3)
examples/vllm_serve/fakequant_worker.py
examples/vllm_serve/vllm_reload_utils.py
modelopt/torch/quantization/nn/modules/quant_module.py
        saved_quant_dict
    )
    saved_quant_dict = {
        key.replace("quantizer_", "quantizer._"): value
Why do we need this?
`quantizer_` -> where do we have this in HF?
I observed this in Megatron exported checkpoint.
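For context, the rename under discussion turns Megatron-style flat keys into the nested attribute path that addresses the quantizer's buffer. A toy illustration (key names are examples, not taken from a real checkpoint):

```python
def normalize_megatron_keys(state: dict) -> dict:
    """Rewrite 'quantizer_' to 'quantizer._' so a flat key like
    'weight_quantizer_amax' resolves to the '_amax' buffer on the
    weight_quantizer submodule."""
    return {k.replace("quantizer_", "quantizer._"): v for k, v in state.items()}

normalize_megatron_keys({"layers.0.weight_quantizer_amax": 0.5})
# → {"layers.0.weight_quantizer._amax": 0.5}
```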
model = self.model_runner.model
if hasattr(model, "unwrap"):
    model = model.unwrap()
if quant_config["modelopt_state_path"]:
Is it possible to group all the restore code into a helper method? The current code organization seems a bit hard to follow to me.
I have rearranged code, can you please review again?
c371c61 to
328fd5b
Compare
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
…on types, remove dead code Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
8ac840f to
d0f1244
Compare
if is_quantlinear(module):
    for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]:
        if hasattr(module, attr):
            delattr(module, attr)
    module.export()
torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
if is_attention(module):
    for attr in [
        "q_bmm_quantizer",
        "k_bmm_quantizer",
        "v_bmm_quantizer",
        "softmax_quantizer",
    ]:
        if hasattr(module, attr):
            delattr(module, attr)
    module.export()
nit:
Can we simplify this with something like:
for cn, cm in module.named_children():
    if isinstance(cm, TensorQuantizer):
        delattr(module, cn)
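The suggested simplification, made runnable against plain stand-in classes (`named_children` is the standard `nn.Module` accessor; the dummy classes below exist only for illustration):

```python
class TensorQuantizer:
    """Stand-in for modelopt's TensorQuantizer."""


class FakeQuantLinear:
    """Stand-in for a quantized linear module with quantizer children."""

    def __init__(self):
        self.weight_quantizer = TensorQuantizer()
        self.input_quantizer = TensorQuantizer()
        self.weight = "weight"

    def named_children(self):
        # Mirrors nn.Module.named_children() for this sketch.
        return list(vars(self).items())


def drop_quantizer_children(module):
    """Delete every TensorQuantizer child without hardcoding attribute names."""
    for name, child in list(module.named_children()):
        if isinstance(child, TensorQuantizer):
            delattr(module, name)


m = FakeQuantLinear()
drop_quantizer_children(m)
# m now keeps only `weight`
```

This avoids maintaining separate attribute lists for linear and attention modules.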
modelopt_state = mto.modelopt_state(model)
modelopt_state["modelopt_state_weights"] = quantizer_state_dict
torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
# Step 3: Remove quantizer attrs from model before saving HF weights.
Actually, this might interfere with the Nemo-RL design. In Nemo-RL, during training we would export to vLLM for rollout -> this design will break that.
-> We should preserve the input model's quantizer state before and after export. Maybe you could temporarily delete the quantizers and set them again after export. (We will have to do similarly for the weight-quantizer disable step.)
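One way to realize "temporarily delete, then reattach" is a context manager around the export step. This is a pure-Python sketch: the attribute names follow the export code above, and `quantizers_detached` is a hypothetical helper, not an existing modelopt API:

```python
from contextlib import contextmanager

QUANTIZER_ATTRS = ("weight_quantizer", "input_quantizer", "output_quantizer")


@contextmanager
def quantizers_detached(module, attrs=QUANTIZER_ATTRS):
    """Detach quantizer attributes while the body runs, then restore them,
    leaving the input model's quantizer state unchanged after export."""
    saved = {a: getattr(module, a) for a in attrs if hasattr(module, a)}
    for a in saved:
        delattr(module, a)
    try:
        yield module
    finally:
        for a, q in saved.items():
            setattr(module, a, q)
```

Wrapping the HF weight save in `with quantizers_detached(module): ...` would cover the quantizer removal; the weight-quantizer disable step mentioned above could reuse the same save/restore pattern.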
all_quantizer_state_dicts = [None] * world_size
torch.distributed.gather_object(quantizer_state_dict, all_quantizer_state_dicts, dst=0)
Can we use the DistributedProcessGroup.get_dist_syncd_obj() API here?
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
Check out model_calib.py for an example.
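Whichever API performs the gather, rank 0 ends up holding a list of per-rank dicts that must be merged into one. A stdlib sketch of that merge step (function name and conflict policy are illustrative):

```python
def merge_rank_states(gathered: list) -> dict:
    """Merge per-rank quantizer state dicts (e.g. the output list of
    torch.distributed.gather_object) into one dict, rejecting conflicts."""
    merged = {}
    for rank, state in enumerate(gathered):
        if state is None:
            continue  # ranks other than dst may contribute None
        for key, value in state.items():
            if key in merged and merged[key] != value:
                raise ValueError(f"Rank {rank} disagrees on key '{key}'")
            merged[key] = value
    return merged

merge_rank_states([{"a._amax": 1.0}, {"b._amax": 2.0}, None])
# → {"a._amax": 1.0, "b._amax": 2.0}
```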
What does this PR do?
Type of change: new feature
Overview:
QUANT_FILE_PATH instead of only amax
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Bug Fixes